In [4]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import time
import datetime as dt
import os
from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.cluster import AgglomerativeClustering
from dtw import dtw
from sklearn.metrics import pairwise_distances
PRE_NAME = "onemin_ohlc_"
BEGIN_TIME = "09:00:00"
END_TIME = "11:00:00"
NUM_CLUSTER = 50
def load_data(data_dir: str = "dataset"):
    """Load per-stock one-minute OHLC CSVs and split each trading day
    into a morning "pattern" window (X) and the rest of the day (Y).

    Every CSV under ``data_dir/<stock_id>/`` is downsampled to every 5th
    row, then split on the ``time`` column: rows with
    ``BEGIN_TIME <= time <= END_TIME`` form the input window, all other
    rows form the target window.  Days whose downsampled morning window
    is not exactly full-length are skipped so X stays rectangular.

    Args:
        data_dir: root directory containing one sub-directory per stock id.

    Returns:
        (X, Y): numpy arrays of the ``return`` column per day; X has one
        row of length 25 per kept day.  Y rows may differ in length, so
        ``np.array(Y)`` can be an object array — verify downstream use.
    """
    # 121 one-minute bars lie in [09:00, 11:00]; taking every 5th row
    # keeps 121 // 5 + 1 == 25 of them when the day starts at BEGIN_TIME.
    expected_front_len = 121 // 5 + 1

    X = []
    Y = []
    for sid in tqdm(os.listdir(data_dir)):
        for fname in os.listdir(os.path.join(data_dir, sid)):
            df = pd.read_csv(os.path.join(data_dir, sid, fname))
            # Downsample: keep every 5th one-minute bar.
            df = df.take(np.arange(0, len(df), 5))
            # Lexicographic comparison is correct because times are
            # zero-padded "HH:MM:SS" strings.
            mask = (df.loc[:, "time"] >= BEGIN_TIME) & (df.loc[:, "time"] <= END_TIME)
            front_df = df[mask].loc[:, "return"]
            end_df = df[~mask].loc[:, "return"]
            # Skip days with an incomplete morning window.
            if len(front_df) == expected_front_len:
                X.append(np.array(front_df))
                Y.append(np.array(end_df))
    return np.array(X), np.array(Y)

def dtw_d(X, Y):
    """DTW distance between two 1-D series, using |x - y| as the local cost."""
    d, _, _, _ = dtw(X, Y, dist=lambda a, b: np.abs(a - b))
    return d

def dtw_affinity(X):
    """Full pairwise DTW distance matrix over the rows of X."""
    distance_matrix = pairwise_distances(X, metric=dtw_d)
    return distance_matrix

X, Y = load_data()
100%|██████████| 13/13 [00:10<00:00,  1.28it/s]
In [5]:
print(X.shape)
(1674, 25)
In [6]:
# Hierarchical clustering of the morning windows.  A callable `affinity`
# receives the whole X and must return the full pairwise distance matrix;
# callable affinities require a non-'ward' linkage, so 'complete' is used.
# NOTE(review): `affinity` was renamed to `metric` in scikit-learn 1.2 and
# removed in 1.4 — confirm the installed version before re-running.
# NOTE(review): DTW over all ~1674² pairs is expensive; consider caching the
# distance matrix and passing affinity='precomputed' when iterating.
ac = AgglomerativeClustering(n_clusters = NUM_CLUSTER,
                             affinity = dtw_affinity,
                             linkage = 'complete')
X_label = ac.fit_predict(X)
In [7]:
X_label
Out[7]:
array([36, 31, 36, ...,  2,  0, 40], dtype=int64)
In [8]:
# Visualise each cluster by overlaying the morning return series of all its
# members, one figure per cluster.
for n in range(NUM_CLUSTER):
    # Boolean-mask selection replaces the original full rescan of X for
    # every cluster (was O(NUM_CLUSTER * len(X)) Python-level iterations).
    members = X[X_label == n]
    for series in members:
        plt.plot(series)
    # Title each figure so the 50 plots can be told apart when skimming.
    plt.title("Cluster %d (%d members)" % (n, len(members)))
    plt.show()
In [9]:
def make_long_simple(y, cost):
    """Best-case long profit: buy at the first price, sell at the peak, minus cost."""
    entry = y[0]
    peak = y.max()
    return peak - entry - cost
def make_short_simple(y, cost):
    """Best-case short profit: sell at the first price, cover at the trough, minus cost."""
    entry = y[0]
    trough = y.min()
    return entry - trough - cost
def make_long_max_lost(y, cost):
    """Worst-case long outcome: buy at the first price, exit at the trough, minus cost."""
    entry = y[0]
    trough = y.min()
    return trough - entry - cost
def make_short_max_lost(y, cost):
    """Worst-case short outcome: sell at the first price, cover at the peak, minus cost."""
    entry = y[0]
    peak = y.max()
    return entry - peak - cost
In [10]:
# Per-cluster accumulators: simulated P&L totals (float) and trade counts (int).
profit_long_array = np.zeros(NUM_CLUSTER)
profit_short_array = np.zeros(NUM_CLUSTER)
lost_long_array = np.zeros(NUM_CLUSTER)
lost_short_array = np.zeros(NUM_CLUSTER)
num_long_array = np.zeros(NUM_CLUSTER, dtype="int")
num_short_array = np.zeros(NUM_CLUSTER, dtype="int")
In [ ]:
 
In [11]:
profit_long_array.shape
Out[11]:
(50,)
In [12]:
# Accumulate per-cluster P&L statistics over the rest-of-day windows.
# A flat 0.002 transaction cost is charged on every simulated trade; name it
# once instead of repeating the magic number four times.
TRADE_COST = 0.002
for i in range(len(X)):
    label = X_label[i]  # cluster of day i's morning pattern (hoisted: was indexed 6x)
    y = Y[i]            # rest-of-day return series for day i
    profit_long_array[label] += make_long_simple(y, TRADE_COST)
    profit_short_array[label] += make_short_simple(y, TRADE_COST)
    lost_long_array[label] += make_long_max_lost(y, TRADE_COST)
    lost_short_array[label] += make_short_max_lost(y, TRADE_COST)
    # NOTE(review): both counters are always incremented together, so these
    # two arrays end up identical — kept separate to preserve the interface.
    num_long_array[label] += 1
    num_short_array[label] += 1
print(profit_long_array)
print(profit_short_array)
print(lost_long_array)
print(lost_short_array)
print(num_long_array)
print(num_short_array)
[ 1.02659784  0.26222199  0.84341681  0.06873012  0.93624636  0.91051522
  0.33265282  1.34029719  0.05221971  0.13623851  0.10770505  0.16468621
  0.19649594  0.70413535  0.19124313  0.22870167  0.02132383  0.26309764
  0.09717009  0.47934743  0.0151934   0.08439834  0.42286481  0.31458964
  0.43833177  0.33806676  0.1841714   0.02381809  0.31187294  0.03450764
  0.10627655  0.24905575  0.89495091  0.18738182  0.1044975   0.06138358
  0.13479673  0.02154402  0.15981723  0.07186119  0.86201771  0.12690538
  0.56368575 -0.00204974  0.04505929  0.11860123  0.91774037 -0.00311816
  0.12778161  0.12274429]
[ 9.10683169e-01  3.25279656e-01  9.94621843e-01  3.15112014e-01
  1.79475795e+00  1.35416169e+00  3.01908161e-01  1.40320704e+00
  9.59187886e-02  2.23028531e-01  1.00140937e-01  3.71060880e-01
  8.88543066e-01  1.09312925e+00  1.40476475e-01  2.88859232e-01
  2.04678216e-01  2.25587160e-01  1.31234909e-01  6.32970461e-01
  9.25130642e-02  2.56455696e-02  4.16086832e-01  1.96872810e-01
  2.38060375e-01  1.05314020e+00  1.53195653e-01  3.32444630e-04
  2.85061385e-01  2.53955575e-02  2.28939849e-01  2.52441152e-01
  7.47871786e-01  8.93162309e-02  4.18886691e-02  1.82883463e-01
  2.09917355e-02  1.02126449e-02  1.91983003e-01  1.14861308e-01
  7.96171999e-01  2.45042104e-01  4.37645545e-01 -2.21718036e-03
  1.22464989e-01  9.47731532e-02  5.59418278e-01  1.10714697e-01
  2.47382966e-01  6.73199120e-02]
[-1.34668317 -0.39727966 -1.32262184 -0.38711201 -2.35475795 -1.92216169
 -0.42590816 -2.04720704 -0.13191879 -0.27902853 -0.16414094 -0.48706088
 -1.09654307 -1.48112925 -0.21247648 -0.37685923 -0.22867822 -0.30158716
 -0.15123491 -0.87297046 -0.12451306 -0.03764557 -0.58008683 -0.30887281
 -0.35406038 -1.2531402  -0.20119565 -0.01233244 -0.36906139 -0.06139556
 -0.28093985 -0.31244115 -1.01987179 -0.11331623 -0.05388867 -0.20288346
 -0.04099174 -0.01821264 -0.259983   -0.12686131 -1.180172   -0.3090421
 -0.58164555 -0.02178282 -0.13846499 -0.12677315 -0.90341828 -0.1227147
 -0.33938297 -0.09531991]
[-1.46259784 -0.33422199 -1.17141681 -0.14073012 -1.49624636 -1.47851522
 -0.45665282 -1.98429719 -0.08821971 -0.19223851 -0.17170505 -0.28068621
 -0.40449594 -1.09213535 -0.26324313 -0.31670167 -0.04532383 -0.33909764
 -0.11717009 -0.71934743 -0.0471934  -0.09639834 -0.58686481 -0.42658964
 -0.55433177 -0.53806676 -0.2321714  -0.03581809 -0.39587294 -0.07050764
 -0.15827655 -0.30905575 -1.16695091 -0.21138182 -0.1164975  -0.08138358
 -0.15479673 -0.02954402 -0.22781723 -0.08386119 -1.24601771 -0.19090538
 -0.70768575 -0.02195026 -0.06105929 -0.15060123 -1.26174037 -0.00888184
 -0.21978161 -0.15074429]
[109  18  82  18 140 142  31 161   9  14  16  29  52  97  18  22   6  19
   5  60   8   3  41  28  29  50  12   3  21   9  13  15  68   6   3   5
   5   2  17   3  96  16  36   6   4   8  86   3  23   7]
[109  18  82  18 140 142  31 161   9  14  16  29  52  97  18  22   6  19
   5  60   8   3  41  28  29  50  12   3  21   9  13  15  68   6   3   5
   5   2  17   3  96  16  36   6   4   8  86   3  23   7]
In [13]:
import pickle
# Persist the fitted clustering model.  Use a context manager so the file
# handle is flushed and closed — the original `pickle.dump(ac, open(...))`
# leaked the open file object.
with open("ac_model2.pc", "wb") as f:
    pickle.dump(ac, f)
# NOTE(review): unpickling executes arbitrary code — only load this file
# from trusted sources.
In [ ]: